{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### How to train our model\n",
    "In this notebook, we present our work on air quality prediction: we train the GeoMAN model on the sample data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\ProgramData\\Anaconda3\\lib\\site-packages\\h5py\\__init__.py:34: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from matplotlib import pyplot as plt\n",
    "import os\n",
    "import json\n",
    "import math\n",
    "import pickle\n",
    "from utils import load_data\n",
    "from utils import load_global_inputs\n",
    "from utils import basic_hyperparams\n",
    "from GeoMAN import GeoMAN\n",
    "from utils import get_batch_feed_dict\n",
    "from utils import shuffle_data\n",
    "from utils import get_valid_batch_feed_dict\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# fix NumPy's global random seed\n",
    "np.random.seed(2017)\n",
    "\n",
    "# expose only GPU 0 to TensorFlow and let it grow its GPU memory on demand\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "tf_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))\n",
    "session = tf.Session(config=tf_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('dropout_rate', 0.3), ('ext_flag', True), ('gc_rate', 2.5), ('gpu_id', '0'), ('lambda_l2_reg', 0.001), ('learning_rate', 0.001), ('n_external_input', 83), ('n_hidden_decoder', 128), ('n_hidden_encoder', 128), ('n_input_decoder', 1), ('n_input_encoder', 19), ('n_output_decoder', 1), ('n_sensors', 35), ('n_stacked_layers', 2), ('n_steps_decoder', 6), ('n_steps_encoder', 12), ('s_attn_flag', 2)]\n"
     ]
    }
   ],
   "source": [
    "# load hyperparameters: build the base set, then override it from the JSON file\n",
    "hps = basic_hyperparams()\n",
    "# a context manager closes the file handle as soon as the JSON is parsed\n",
    "with open('./hparam_files/AirQualityGeoMan.json', 'r') as hps_file:\n",
    "    hps_dict = json.load(hps_file)\n",
    "hps.override_from_dict(hps_dict)\n",
    "print(hps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train samples: 100\n",
      "eval samples: 10\n"
     ]
    }
   ],
   "source": [
    "# location of the sample dataset\n",
    "input_path = './sample_data/'\n",
    "# encoder / decoder step counts are passed to every loader call\n",
    "n_steps = (hps.n_steps_encoder, hps.n_steps_decoder)\n",
    "\n",
    "# load the training split, the evaluation split and the global inputs\n",
    "training_data = load_data(input_path, 'train', *n_steps)\n",
    "valid_data = load_data(input_path, 'eval', *n_steps)\n",
    "global_inputs, global_attn_states = load_global_inputs(input_path, *n_steps)\n",
    "\n",
    "# dataset sizes are read from the first element of each split\n",
    "num_train, num_valid = len(training_data[0]), len(valid_data[0])\n",
    "print('train samples: {0}'.format(num_train))\n",
    "print('eval samples: {0}'.format(num_valid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Summary name Weights_out:0 is illegal; using Weights_out_0 instead.\n",
      "INFO:tensorflow:Summary name Biases_out:0 is illegal; using Biases_out_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Encoder/spatial_attn/local_spatial_attn/AttnUl:0 is illegal; using GeoMAN/Encoder/spatial_attn/local_spatial_attn/AttnUl_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Encoder/spatial_attn/local_spatial_attn/AttnVl:0 is illegal; using GeoMAN/Encoder/spatial_attn/local_spatial_attn/AttnVl_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Encoder/spatial_attn/global_spatial_attn/AttnW_and_u:0 is illegal; using GeoMAN/Encoder/spatial_attn/global_spatial_attn/AttnW_and_u_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Encoder/spatial_attn/global_spatial_attn/AttnVg:0 is illegal; using GeoMAN/Encoder/spatial_attn/global_spatial_attn/AttnVg_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Encoder/spatial_attn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0 is illegal; using GeoMAN/Encoder/spatial_attn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Encoder/spatial_attn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0 is illegal; using GeoMAN/Encoder/spatial_attn/multi_rnn_cell/cell_0/basic_lstm_cell/bias_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Encoder/spatial_attn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0 is illegal; using GeoMAN/Encoder/spatial_attn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Encoder/spatial_attn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0 is illegal; using GeoMAN/Encoder/spatial_attn/multi_rnn_cell/cell_1/basic_lstm_cell/bias_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Encoder/spatial_attn/local_spatial_attn/AttnWl/kernel:0 is illegal; using GeoMAN/Encoder/spatial_attn/local_spatial_attn/AttnWl/kernel_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Encoder/spatial_attn/local_spatial_attn/AttnWl/bias:0 is illegal; using GeoMAN/Encoder/spatial_attn/local_spatial_attn/AttnWl/bias_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Encoder/spatial_attn/global_spatial_attn/AttnWg/kernel:0 is illegal; using GeoMAN/Encoder/spatial_attn/global_spatial_attn/AttnWg/kernel_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Encoder/spatial_attn/global_spatial_attn/AttnWg/bias:0 is illegal; using GeoMAN/Encoder/spatial_attn/global_spatial_attn/AttnWg/bias_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Decoder/temporal_attn/Attn_Wd:0 is illegal; using GeoMAN/Decoder/temporal_attn/Attn_Wd_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Decoder/temporal_attn/Attn_v:0 is illegal; using GeoMAN/Decoder/temporal_attn/Attn_v_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Decoder/temporal_attn/kernel:0 is illegal; using GeoMAN/Decoder/temporal_attn/kernel_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Decoder/temporal_attn/bias:0 is illegal; using GeoMAN/Decoder/temporal_attn/bias_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Decoder/temporal_attn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0 is illegal; using GeoMAN/Decoder/temporal_attn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Decoder/temporal_attn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0 is illegal; using GeoMAN/Decoder/temporal_attn/multi_rnn_cell/cell_0/basic_lstm_cell/bias_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Decoder/temporal_attn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0 is illegal; using GeoMAN/Decoder/temporal_attn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Decoder/temporal_attn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0 is illegal; using GeoMAN/Decoder/temporal_attn/multi_rnn_cell/cell_1/basic_lstm_cell/bias_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Decoder/temporal_attn/Attn_Wpd/kernel:0 is illegal; using GeoMAN/Decoder/temporal_attn/Attn_Wpd/kernel_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Decoder/temporal_attn/Attn_Wpd/bias:0 is illegal; using GeoMAN/Decoder/temporal_attn/Attn_Wpd/bias_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Decoder/temporal_attn/AttnOutputProjection/kernel:0 is illegal; using GeoMAN/Decoder/temporal_attn/AttnOutputProjection/kernel_0 instead.\n",
      "INFO:tensorflow:Summary name GeoMAN/Decoder/temporal_attn/AttnOutputProjection/bias:0 is illegal; using GeoMAN/Decoder/temporal_attn/AttnOutputProjection/bias_0 instead.\n",
      "<tf.Variable 'Weights_out:0' shape=(128, 1) dtype=float32_ref>\n",
      "<tf.Variable 'Biases_out:0' shape=(1,) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Encoder/spatial_attn/local_spatial_attn/AttnUl:0' shape=(1, 1, 12, 12) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Encoder/spatial_attn/local_spatial_attn/AttnVl:0' shape=(12,) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Encoder/spatial_attn/global_spatial_attn/AttnW_and_u:0' shape=(1, 19, 12, 12) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Encoder/spatial_attn/global_spatial_attn/AttnVg:0' shape=(12,) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Encoder/spatial_attn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0' shape=(182, 512) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Encoder/spatial_attn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0' shape=(512,) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Encoder/spatial_attn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0' shape=(256, 512) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Encoder/spatial_attn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0' shape=(512,) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Encoder/spatial_attn/local_spatial_attn/AttnWl/kernel:0' shape=(512, 12) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Encoder/spatial_attn/local_spatial_attn/AttnWl/bias:0' shape=(12,) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Encoder/spatial_attn/global_spatial_attn/AttnWg/kernel:0' shape=(512, 12) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Encoder/spatial_attn/global_spatial_attn/AttnWg/bias:0' shape=(12,) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Decoder/temporal_attn/Attn_Wd:0' shape=(1, 1, 128, 128) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Decoder/temporal_attn/Attn_v:0' shape=(128,) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Decoder/temporal_attn/kernel:0' shape=(212, 1) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Decoder/temporal_attn/bias:0' shape=(1,) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Decoder/temporal_attn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0' shape=(129, 512) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Decoder/temporal_attn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0' shape=(512,) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Decoder/temporal_attn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0' shape=(256, 512) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Decoder/temporal_attn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0' shape=(512,) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Decoder/temporal_attn/Attn_Wpd/kernel:0' shape=(512, 128) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Decoder/temporal_attn/Attn_Wpd/bias:0' shape=(128,) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Decoder/temporal_attn/AttnOutputProjection/kernel:0' shape=(256, 128) dtype=float32_ref>\n",
      "<tf.Variable 'GeoMAN/Decoder/temporal_attn/AttnOutputProjection/bias:0' shape=(128,) dtype=float32_ref>\n",
      "[<tf.Operation 'inputs/local_inputs' type=Placeholder>, <tf.Operation 'inputs/global_inputs' type=Placeholder>, <tf.Operation 'inputs/external_inputs' type=Placeholder>, <tf.Operation 'inputs/local_attn_states' type=Placeholder>, <tf.Operation 'inputs/global_attn_states' type=Placeholder>, <tf.Operation 'groundtruth/labels' type=Placeholder>]\n",
      "total parameters: 554054\n"
     ]
    }
   ],
   "source": [
    "# build the model in a fresh default graph\n",
    "tf.reset_default_graph()\n",
    "model = GeoMAN(hps)\n",
    "\n",
    "# show every trainable variable of the graph\n",
    "trainable_vars = tf.trainable_variables()\n",
    "for var in trainable_vars:\n",
    "    print(var)\n",
    "\n",
    "# show every placeholder op of the graph\n",
    "phs = [op for op in tf.get_default_graph().get_operations()\n",
    "       if op.type == \"Placeholder\"]\n",
    "print(phs)\n",
    "\n",
    "\n",
    "def count_params(variable):\n",
    "    \"\"\"Return the product of the dimensions of `variable`'s static shape.\"\"\"\n",
    "    n_params = 1\n",
    "    # the shape iterates as tf.Dimension objects; .value is the plain size\n",
    "    for dim in variable.get_shape():\n",
    "        n_params *= dim.value\n",
    "    return n_params\n",
    "\n",
    "\n",
    "# total number of trainable parameters in the model\n",
    "total_parameters = sum(count_params(var) for var in trainable_vars)\n",
    "print('total parameters: {}'.format(total_parameters))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# path for log saving: the run name is picked from ext_flag / s_attn_flag\n",
    "if not hps.ext_flag:\n",
    "    model_name = 'GeoMANne'\n",
    "elif hps.s_attn_flag == 0:\n",
    "    model_name = 'GeoMANng'\n",
    "elif hps.s_attn_flag == 1:\n",
    "    model_name = 'GeoMANnl'\n",
    "else:\n",
    "    model_name = 'GeoMAN'\n",
    "\n",
    "# the directory name also encodes the main hyperparameters of the run\n",
    "logdir = './logs/{}-{}-{}-{}-{}-{:.2f}-{:.3f}/'.format(\n",
    "    model_name,\n",
    "    hps.n_steps_encoder,\n",
    "    hps.n_steps_decoder,\n",
    "    hps.n_stacked_layers,\n",
    "    hps.n_hidden_encoder,\n",
    "    hps.dropout_rate,\n",
    "    hps.lambda_l2_reg)\n",
    "model_dir = logdir + 'saved_models/'\n",
    "# os.makedirs also creates missing parents ('./logs/' and logdir), whereas\n",
    "# os.mkdir raised FileNotFoundError on a fresh checkout without './logs/';\n",
    "# exist_ok=True keeps the cell safe to re-run\n",
    "os.makedirs(model_dir, exist_ok=True)\n",
    "results_dir = logdir + 'results/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# train params\n",
    "# number of passes over the training set\n",
    "total_epoch = 15\n",
    "batch_size = 16 # mini-batch size, chosen for the small sample data\n",
    "# the validation loss is evaluated every `display_iter` training iterations\n",
    "display_iter = 1\n",
    "# the training summary is written to the log every `save_log_iter` iterations\n",
    "save_log_iter = 10\n",
    "# number of boundary points given to np.linspace when splitting the\n",
    "# validation set, so validation runs in n_split_valid - 1 batches\n",
    "n_split_valid = 2  # 2 -> a single validation batch\n",
    "# history of validation losses; the leading inf makes the first comparison\n",
    "# against min(valid_losses[:-1]) succeed\n",
    "valid_losses = [np.inf]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------epoch 0-----------\n",
      "iter 1\tvalid_loss = 0.558732\tmodel saved!!\n",
      "iter 2\tvalid_loss = 0.364431\tmodel saved!!\n",
      "iter 3\tvalid_loss = 0.651559\t\n",
      "iter 4\tvalid_loss = 0.617913\t\n",
      "iter 5\tvalid_loss = 0.197884\tmodel saved!!\n"
     ]
    }
   ],
   "source": [
    "# training process\n",
    "# Start every (re-)run of this cell from a clean validation history so that\n",
    "# checkpoints are never compared against the losses of a previous run.\n",
    "valid_losses = [np.inf]\n",
    "best_valid_loss = np.inf\n",
    "# The validation batch boundaries do not change during training, so they\n",
    "# are computed once instead of at every iteration.\n",
    "valid_indexes = np.int64(np.linspace(0, num_valid, n_split_valid))\n",
    "n_valid_batches = n_split_valid - 1\n",
    "\n",
    "# reuse the GPU options (allow_growth) prepared in the configuration cell\n",
    "with tf.Session(config=tf_config) as sess:\n",
    "    saver = tf.train.Saver()\n",
    "\n",
    "    # initialize\n",
    "    model.init(sess)\n",
    "    # number of training batches processed so far; deliberately not called\n",
    "    # `iter`, which would shadow the builtin for the rest of the kernel\n",
    "    global_step = 0\n",
    "    summary_writer = tf.summary.FileWriter(logdir)\n",
    "\n",
    "    try:\n",
    "        for epoch in range(total_epoch):\n",
    "            print('----------epoch {}-----------'.format(epoch))\n",
    "            training_data = shuffle_data(training_data)\n",
    "\n",
    "            for batch_start in range(0, num_train, batch_size):\n",
    "                global_step += 1\n",
    "                feed_dict = get_batch_feed_dict(\n",
    "                    model, batch_start, batch_size, training_data,\n",
    "                    global_inputs, global_attn_states)\n",
    "                _, merged_summary = sess.run(\n",
    "                    [model.phs['train_op'], model.phs['summary']], feed_dict)\n",
    "                if global_step % save_log_iter == 0:\n",
    "                    summary_writer.add_summary(merged_summary, global_step)\n",
    "                if global_step % display_iter != 0:\n",
    "                    continue\n",
    "\n",
    "                # average the loss over all validation batches\n",
    "                valid_loss = 0\n",
    "                for k in range(n_valid_batches):\n",
    "                    feed_dict = get_valid_batch_feed_dict(\n",
    "                        model, valid_indexes, k, valid_data,\n",
    "                        global_inputs, global_attn_states)\n",
    "                    valid_loss += sess.run(model.phs['loss'], feed_dict)\n",
    "                valid_loss /= n_valid_batches\n",
    "                valid_losses.append(valid_loss)\n",
    "                valid_loss_sum = tf.Summary(\n",
    "                    value=[tf.Summary.Value(tag=\"valid_loss\", simple_value=valid_loss)])\n",
    "                summary_writer.add_summary(valid_loss_sum, global_step)\n",
    "\n",
    "                # checkpoint only when the validation loss reaches a new best\n",
    "                if valid_loss < best_valid_loss:\n",
    "                    best_valid_loss = valid_loss\n",
    "                    print('iter {}\\tvalid_loss = {:.6f}\\tmodel saved!!'.format(\n",
    "                        global_step, valid_loss))\n",
    "                    saver.save(sess, model_dir +\n",
    "                               'model_{}.ckpt'.format(global_step))\n",
    "                    saver.save(sess, model_dir + 'final_model.ckpt')\n",
    "                else:\n",
    "                    print('iter {}\\tvalid_loss = {:.6f}\\t'.format(\n",
    "                        global_step, valid_loss))\n",
    "    finally:\n",
    "        # flush pending events and release the event file, even if training\n",
    "        # is interrupted\n",
    "        summary_writer.close()\n",
    "\n",
    "print('stop training !!!')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
